import argparse
import multiprocessing
import os
import random
import sys
import time
from multiprocessing import Pool

import torch
import torch.nn as nn
from joblib import Parallel, delayed
from sklearn.linear_model import LogisticRegression
from torch import Tensor
from torch.utils.data import DataLoader

from controlsnr import find_a_given_snr
# from data_generator_dcsbm import Generator, simple_collate_fn
from data_generator_lsm import Generator,simple_collate_fn
from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from models import GNN_multiclass, GNN_multiclass_second_period
from spectral_clustering import spectral_clustering_adj
from train_first_period import train_first_period_with_early_stopping
from train_second_period import train_second_period_with_early_stopping


def setup_logger(prefix="main_gnn"):
    """Redirect ``sys.stdout`` and ``sys.stderr`` to a fresh per-process log file.

    The file is created in the current working directory and is named
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``.  It is opened line-buffered and
    is deliberately left open, because it backs both process-wide streams
    from this point on.

    The file is opened as UTF-8 explicitly: this script prints non-ASCII
    (Chinese) messages, and relying on the locale's default encoding can
    raise ``UnicodeEncodeError`` on platforms whose default is not UTF-8.

    Args:
        prefix: Leading part of the log file name.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    log_filename = f"{prefix}_{timestamp}_{pid}.log"
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")  # line-buffered
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")

def load_best_model_into(model, ckpt_path, device):
    """
    Load the "best model" checkpoint stored at ``ckpt_path`` into ``model``.

    Three checkpoint layouts are handled:
    - a dict carrying the weights under the ``'model_state'`` key
      (``torch.save({'model_state': state_dict, ...})``);
    - a whole pickled model object (anything exposing ``state_dict``);
    - anything else, which is passed to ``load_state_dict`` as-is.

    Weights are always loaded with ``strict=False``.  ``model`` is returned
    in every case except one: when the checkpoint is a whole model whose
    state dict cannot be loaded into ``model``, the checkpoint object itself
    is moved to ``device`` and returned instead.

    The checkpoint is read with ``weights_only=False``, i.e. it is fully
    unpickled, so it must come from a trusted source.
    """
    loaded = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Layout 1: weights wrapped in a dict under 'model_state'.
    wrapped_state = isinstance(loaded, dict) and 'model_state' in loaded
    if wrapped_state:
        model.load_state_dict(loaded['model_state'], strict=False)
        return model

    # Layout 3: no ``state_dict`` attribute, so treat it as a state dict.
    if not hasattr(loaded, 'state_dict'):
        model.load_state_dict(loaded, strict=False)
        return model

    # Layout 2: a whole model object; prefer copying its weights into ``model``.
    try:
        model.load_state_dict(loaded.state_dict(), strict=False)
    except Exception:
        # Fall back to the unpickled model itself, moved to the target device.
        return loaded.to(device)
    return model

def maybe_freeze(model, freeze=True):
    """Switch ``model`` to eval mode, first disabling gradients if ``freeze``.

    With ``freeze`` true, every parameter gets ``requires_grad = False``.
    ``model.eval()`` is called regardless of ``freeze``.
    """
    params_to_freeze = model.parameters() if freeze else ()
    for param in params_to_freeze:
        param.requires_grad = False
    model.eval()

# Command-line configuration. Arguments are parsed at import time (see
# `args = parser.parse_args()` at the end of this section), so `args` is a
# module-level global used by the functions below.
parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#                 Configure parameters up front for later use                 #
###############################################################################

parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6)) #originally 6000
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(1))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1)) #originally 1000
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 1)
parser.add_argument('--mode', nargs='?', const=1, type=str, default='train')
# Directory containing this script; used as the default for --path_gnn.
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): --freeze_bn stores True into `eval_vs_train`, and the default
# set on the next line is also True, so passing the flag cannot change the
# value. Confirm whether the default was meant to be False.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

###############################################################################
#                                 GNN first period                            #
###############################################################################
parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default=10)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=2)
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

# Parsed at import time: importing this module consumes sys.argv.
args = parser.parse_args()

# Select the tensor-type classes used for the `.type(...)` casts further down:
# the CUDA variants when a GPU is available, the CPU variants otherwise.
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    dtype_l = torch.cuda.LongTensor
    # torch.cuda.manual_seed(0)
else:
    dtype = torch.FloatTensor
    dtype_l = torch.LongTensor
    # torch.manual_seed(1)

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Fixed-width format strings for console progress tables.
# template1/template2 are the header/row pair printed by
# test_single_first_period; the remaining templates are not referenced in the
# visible part of this file.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'
template3 = '{:<10} {:<10} {:<10} '
template4 = '{:<10} {:<10.5f} {:<10.5f} \n'

template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row    = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'


class SBMDataset(torch.utils.data.Dataset):
    """Dataset over ``.npz`` files, each holding one graph in CSR form plus labels.

    Every file must provide the arrays ``adj_data``, ``adj_indices``,
    ``adj_indptr`` and ``adj_shape`` (the CSR components of the adjacency
    matrix) and ``labels``.
    """

    def __init__(self, npz_file_list):
        # Sequence of paths to the .npz files; one sample per file.
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``{'adj', 'labels', 'num_nodes'}``.

        ``adj`` is a ``scipy.sparse.csr_matrix``, ``labels`` is the stored
        label array and ``num_nodes`` is ``adj.shape[0]``.
        """
        # ``np.load`` on an .npz returns a lazy archive that keeps the file
        # open; the ``with`` block closes it once the arrays have been read,
        # so handles are not leaked across many samples.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}


# Imports for sparse matrices, local refinement and the eigen-solvers
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors
from scipy.sparse import issparse
from scipy.sparse.linalg import eigsh as sparse_eigsh
from numpy.linalg import eigh as dense_eigh



def extract_spectral_features(
    adj,
    k: int,
    *,
    row_norm: bool = True,
    use_abs: bool = True,   # True: rank by |eigenvalue|; False: rank by algebraic value
):
    """
    Eigendecompose the adjacency matrix and return ``k`` eigenvectors as node features.

    Parameters
    ----------
    adj : (N, N) sparse or dense adjacency matrix (expected to be real symmetric).
    k : Number of eigenvectors to keep.
    row_norm : If True (default), scale every row of the result to (about) unit L2 norm.
    use_abs : If True (default), keep the eigenvectors whose eigenvalues have the
        largest magnitude; otherwise keep those with the largest algebraic value.

    Returns
    -------
    U : np.ndarray whose columns are the selected eigenvectors, ordered from the
        highest-ranked eigenvalue downwards.  Each column's sign is chosen so
        that its mean is non-negative.  In the sparse path the number of
        eigenpairs requested from the solver is capped at ``N - 2`` (minimum 1).
    """
    if issparse(adj):
        # The solver needs a floating-point matrix.
        adj = adj.astype(np.float64)
        # Cap the request: the sparse solver requires fewer eigenpairs than N.
        n_eigs = max(1, min(k, adj.shape[0] - 2))
        which = "LM" if use_abs else "LA"  # largest magnitude / largest algebraic
        w, vecs = sparse_eigsh(adj, k=n_eigs, which=which)
        # Re-sort explicitly so the column order follows the selection rule.
        ranking_key = np.abs(w) if use_abs else w
        order = np.argsort(ranking_key)[::-1]
        U = vecs[:, order[:k]]
    else:
        dense = np.asarray(adj, dtype=np.float64)
        # Full spectrum first (ascending eigenvalues), then pick columns by hand.
        w_all, V_all = dense_eigh(dense)
        if use_abs:
            magnitudes = np.abs(w_all)
            picked = np.argsort(magnitudes)[-k:]
            picked = picked[np.argsort(magnitudes[picked])[::-1]]
        else:
            picked = np.argsort(w_all)[-k:]
            picked = picked[::-1]
        U = V_all[:, picked]

    # Fix the sign of every column so repeated runs do not flip whole vectors.
    signs = np.where(U.mean(axis=0, keepdims=True) >= 0, 1.0, -1.0)
    U *= signs

    # Optional row normalisation; the epsilon guards against all-zero rows.
    if row_norm:
        row_lengths = np.linalg.norm(U, axis=1, keepdims=True)
        U = U / (row_lengths + 1e-12)

    return U


import joblib

# Default the OpenMP and MKL thread counts to 1 unless the environment already
# sets them (presumably to avoid oversubscription when joblib workers run in
# parallel — confirm).
# NOTE(review): numpy/torch are imported earlier in this file; verify these
# variables are still picked up when set this late.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

def _load_and_feat(path, k):
    """Load one ``.npz`` graph file and return ``(spectral_features, labels)``.

    The file must contain the CSR components ``adj_data``, ``adj_indices``,
    ``adj_indptr``, ``adj_shape`` and a ``labels`` array.  Labels are
    flattened and cast to ``int``; features come from
    ``extract_spectral_features(adj, k)`` on the sparse adjacency matrix.
    """
    # Close the .npz archive as soon as the arrays are read: this function is
    # run once per training file, so an unclosed archive would leak a file
    # handle per call.
    with np.load(path) as data:
        adj = csr_matrix((data['adj_data'], data['adj_indices'], data['adj_indptr']),
                         shape=tuple(data['adj_shape']))
        labels = np.asarray(data['labels']).ravel().astype(int)
    feats = extract_spectral_features(adj, k)   # sparse input -> sparse eigen-solver path
    return feats, labels

def train_logistic_regression(gen, k, save_path="./models/node_info_added_lr_model.pkl",
                              n_jobs=max(1, (os.cpu_count() or 1) // 2)):
    """Train (or load a cached) logistic-regression classifier on spectral features.

    If ``save_path`` already exists, the stored model is loaded and returned
    without any training.  Otherwise spectral features are extracted in
    parallel from every file in ``gen.data_train``, a multinomial logistic
    regression is fitted on the stacked features, and the model is saved to
    ``save_path``.

    Args:
        gen: Object exposing ``data_train``, an iterable of ``.npz`` paths.
        k: Number of eigenvectors used as features per node.
        save_path: Where the fitted model is cached.
        n_jobs: Number of joblib workers.  Defaults to half the CPU count,
            but never less than 1 (``os.cpu_count()`` may return ``None`` or
            1, and joblib rejects ``n_jobs=0``).

    Returns:
        The fitted (or loaded) ``LogisticRegression`` model.
    """
    # Reuse the cached model when present.
    # NOTE: joblib.load unpickles the file; only use trusted model files.
    if os.path.exists(save_path):
        print(f"检测到已存在的模型文件 {save_path}，直接加载。")
        return joblib.load(save_path)

    # Extract features for all training graphs in parallel.
    results = Parallel(n_jobs=n_jobs, prefer="processes")(
        delayed(_load_and_feat)(fp, k) for fp in gen.data_train
    )
    all_features, all_labels = zip(*results)

    X_train = np.vstack(all_features).astype(np.float32, copy=False)
    y_train = np.concatenate(all_labels)

    # NOTE(review): `multi_class` is reported as deprecated in recent
    # scikit-learn releases — confirm against the installed version.
    lr_model = LogisticRegression(
        multi_class='multinomial',
        solver='saga',
        max_iter=1000,
        random_state=42
    )

    lr_model.fit(X_train, y_train)

    # Persist the model.  ``os.makedirs('')`` raises, so only create the
    # directory when ``save_path`` actually has a directory component.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    joblib.dump(lr_model, save_path)
    print(f"模型已保存到 {save_path}")

    return lr_model


def evaluate_using_logitic(W, true_labels ,lr_model, k):
    """
    Score the logistic-regression baseline on one graph, before and after local refinement.

    :param W: Adjacency of the graph to test, with a leading axis of size 1
        (it is squeezed away).  May be a ``torch.Tensor`` or an array.
    :param true_labels: The true labels
    :param lr_model: The well-trained logistic regression model
    :param k: Number of communities; also the number of eigenvectors used as features
    :return: ``(acc, ari, nmi, acc_refined, ari_refined, nmi_refined)`` — the
        metrics of the raw prediction followed by those of the locally
        refined prediction
    """
    W = W.squeeze(0)
    # Mirror the handling in test_single_first_period: a (possibly CUDA)
    # tensor cannot be fed to the NumPy eigen-solver directly, so move it to
    # host memory and convert it first.
    if isinstance(W, Tensor):
        W = W.detach().cpu().numpy()

    # Spectral features of the graph.
    features = extract_spectral_features(W, k)

    # Predicted labels.
    pred_labels = lr_model.predict(features)

    # compute_acc_ari_nmi returns (best_acc, best_pred, ari, nmi).
    acc_logistic, logistic_best_matched_pred, ari_logistic, nmi_logistic = compute_acc_ari_nmi(pred_labels, true_labels, k)

    # Local refinement of the best-matched prediction, then re-score.
    logistic_best_matched_pred_refined = local_refinement_by_neighbors(W, logistic_best_matched_pred, k)
    acc_logistic_refined, best_pred, ari_logistic_refined, nmi_logistic_refined = compute_acc_ari_nmi(logistic_best_matched_pred_refined, true_labels, k)

    return acc_logistic, ari_logistic, nmi_logistic, acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined

def get_available_device():
    """Return the first CUDA device that accepts a tiny allocation, else the CPU.

    Devices are probed in index order.  Probing calls
    ``torch.cuda.set_device``, so the process's current CUDA device is
    changed as a side effect.  A device whose probe raises ``RuntimeError``
    is skipped.
    """
    gpu_total = torch.cuda.device_count()
    gpu_index = 0
    while gpu_index < gpu_total:
        try:
            # Allocate a throw-away tensor to check the device is usable.
            torch.cuda.set_device(gpu_index)
            torch.zeros(1).cuda()
        except RuntimeError:
            gpu_index += 1
        else:
            return torch.device(f"cuda:{gpu_index}")
    return torch.device("cpu")


# Module-level device, chosen once at import time by probing the GPUs.
device = get_available_device()




def test_single_first_period(gnn_first_period, lr_model, gen, n_classes, args, iter, mode='balanced', class_sizes=None):
    """Evaluate the first-stage GNN and all baselines on one freshly sampled graph.

    A graph is drawn from ``gen.imbalanced_sample_otf_single`` (with the class
    sizes reversed with probability 0.5), then scored with spectral
    clustering, the logistic-regression baseline and the GNN, each with and
    without local refinement.  A progress row is printed, and for
    ``iter < 10`` a per-node sheet is written to an Excel workbook under
    ``penultimate_GNN_Feature/``.

    Args:
        gnn_first_period: First-stage GNN; must expose ``get_penultimate_output()``.
        lr_model: Fitted logistic-regression model for the baseline.
        gen: Graph generator providing ``imbalanced_sample_otf_single``,
            ``p_SBM`` and ``q_SBM``.
        n_classes: Number of communities.
        args: Parsed command-line arguments (``J``, ``num_layers``,
            ``edge_density`` and ``noise`` are read).
        iter: Index of this evaluation; selects the Excel sheet name and
            whether the workbook is created (0) or appended to (1..9).
        mode: Sampling mode.  Only ``'imbalanced'`` is implemented.
        class_sizes: Class sizes passed to the imbalanced sampler.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the GNN loss as a float and
        ``metrics`` maps each method name to ``{"acc", "ari", "nmi"}``.

    Raises:
        ValueError: If ``mode`` is not ``'imbalanced'`` or ``class_sizes`` is
            missing.  (Previously these cases failed later with an
            ``UnboundLocalError`` / ``TypeError``.)
    """
    # Only the imbalanced sampler is wired up in this function; fail with a
    # clear message instead of an unbound-variable error further down.
    if mode != 'imbalanced':
        raise ValueError(
            f"test_single_first_period only supports mode='imbalanced', got {mode!r}"
        )
    if class_sizes is None:
        raise ValueError("class_sizes is required when mode='imbalanced'")

    start = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the model is put in train mode although this is an
    # evaluation pass (the forward pass below runs under no_grad) — confirm
    # this is intentional.
    gnn_first_period.train()

    # Randomly reverse the order of the class sizes.
    if random.random() < 0.5:
        random_class_sizes = class_sizes[::-1]
    else:
        random_class_sizes = class_sizes

    W, true_labels, eigvecs_top = gen.imbalanced_sample_otf_single(random_class_sizes, is_training=True, cuda=True)
    true_labels = true_labels.type(dtype_l)

    # Spectral-clustering baselines (normalized / unnormalized / adjacency).
    res_spectral = spectral_clustering_adj(
        W, n_classes, true_labels,
        normalized=True,
        run_all=True,
        random_state=0
    )

    # Logistic-regression baseline, raw and refined.
    acc_logistic, ari_logistic, nmi_logistic, acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined\
        = evaluate_using_logitic(W, true_labels, lr_model, n_classes)

    # GNN inputs.
    WW, x = get_gnn_inputs(W, args.J)
    WW, x = WW.to(device), x.to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        pred_single_first = gnn_first_period(WW.type(dtype), x.type(dtype))

    # Top eigenvectors of the (symmetrised) adjacency matrix.
    W_np = W.squeeze(0).cpu().numpy() if isinstance(W, Tensor) else W.squeeze(0)

    W_for_eig = (W_np + W_np.T) / 2  # enforce symmetry
    eigvals_W, eigvecs_W = np.linalg.eigh(W_for_eig)
    eigvals_W, eigvecs_W = np.real(eigvals_W), np.real(eigvecs_W)
    adjacency_eigvecs = eigvecs_W[:, np.argsort(eigvals_W)[-n_classes:][::-1]]
    adjacency_eigvecs /= np.linalg.norm(adjacency_eigvecs, axis=0, keepdims=True)

    # Penultimate-layer features of the GNN, column-normalised.
    penultimate_features = gnn_first_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)
    penultimate_features /= np.linalg.norm(penultimate_features, axis=0, keepdims=True)

    # First-stage loss and accuracy.
    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    # Returns (acc_mean, best_matched_preds, ari_mean, nmi_mean).
    acc_gnn_first, best_matched_pred, ari_gnn_first, nmi_gnn_first = gnn_compute_acc_ari_nmi_multiclass(pred_single_first, true_labels, n_classes)
    gnn_pred_label_first = best_matched_pred
    # Local refinement of the GNN prediction.
    gnn_refined = local_refinement_by_neighbors(W_np, gnn_pred_label_first, n_classes)

    acc_gnn_refined, best_matched_pred, ari_gnn_refined, nmi_gnn_refined = compute_acc_ari_nmi(gnn_refined, true_labels, n_classes)


    N = true_labels.shape[1]
    elapsed = time.time() - start

    # ``.cpu()`` works for CPU and CUDA tensors alike, so no branch is needed.
    loss_value = float(loss_test_first.data.cpu().numpy())

    info = ['iter', 'avg loss', 'avg acc', 'edge_density',
            'noise', 'model', 'elapsed']
    out = [iter, loss_value, acc_gnn_first, args.edge_density,
           args.noise, 'GNN', elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    del WW
    del x

    # Per-node table for the Excel export.
    data = {
        'True_Label': true_labels.squeeze(0).cpu().numpy(),
        'Pred_Label_First': gnn_pred_label_first.reshape(-1),
        'Loss_First': [float(loss_test_first)] * N,
        'Acc_First': [float(acc_gnn_first)] * N,
    }

    # The first 2 * n_classes penultimate feature columns are exported.
    for i in range(2 * n_classes):
        data[f'penultimate_GNN_Feature{i + 1}'] = penultimate_features[:, i]
    for i in range(n_classes):
        data[f'eigvecs_top{i + 1}'] = eigvecs_top[:, i]
    for i in range(n_classes):
        data['Adj_EigVecs_Top' + str(i + 1) + ''] = adjacency_eigvecs[:, i]

    df = pd.DataFrame(data)

    # Excel output location.
    root_folder = "penultimate_GNN_Feature"
    subfolder_name = f"penultimate_GNN_Feature_nclasses_{n_classes}"
    output_filename = (
        f"first_gnn_classesizes={class_sizes}_p={gen.p_SBM}_q={gen.q_SBM}_"
        f"j={args.J}_nlyr={args.num_layers}.xlsx"
    )
    output_path = os.path.join(root_folder, subfolder_name, output_filename)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Only the first ten iterations are exported: iteration 0 (re)creates the
    # workbook, iterations 1..9 append one sheet each.
    if iter < 10:
        if iter == 0:
            df.to_excel(output_path, sheet_name=f'Iteration_{iter}', index=False)
        else:
            with pd.ExcelWriter(output_path, mode='a', engine='openpyxl') as writer:
                df.to_excel(writer, sheet_name=f'Iteration_{iter}', index=False)

    # Unpack the three spectral variants, each raw ('sc') and refined.
    acc_sc_n, ari_sc_n, nmi_sc_n = res_spectral['normalized']['sc']
    acc_rf_n, ari_rf_n, nmi_rf_n = res_spectral['normalized']['refined']

    acc_sc_u, ari_sc_u, nmi_sc_u = res_spectral['unnormalized']['sc']
    acc_rf_u, ari_rf_u, nmi_rf_u = res_spectral['unnormalized']['refined']

    acc_sc_a, ari_sc_a, nmi_sc_a = res_spectral['adjacency']['sc']
    acc_rf_a, ari_rf_a, nmi_rf_a = res_spectral['adjacency']['refined']

    metrics = {
        "gnn": {
            "acc": float(acc_gnn_first),
            "ari": float(ari_gnn_first),
            "nmi": float(nmi_gnn_first),
        },
        "gnn_refined": {
            "acc": float(acc_gnn_refined),
            "ari": float(ari_gnn_refined),
            "nmi": float(nmi_gnn_refined),
        },

        # Spectral clustering on the normalized and unnormalized variants.
        "spectral_normalized": {"acc": float(acc_sc_n), "ari": float(ari_sc_n), "nmi": float(nmi_sc_n)},
        "spectral_normalized_refined": {"acc": float(acc_rf_n), "ari": float(ari_rf_n), "nmi": float(nmi_rf_n)},

        "spectral_unnormalized": {"acc": float(acc_sc_u), "ari": float(ari_sc_u), "nmi": float(nmi_sc_u)},
        "spectral_unnormalized_refined": {"acc": float(acc_rf_u), "ari": float(ari_rf_u), "nmi": float(nmi_rf_u)},

        # Spectral clustering on the adjacency matrix.
        "spectral_adjacency": {"acc": float(acc_sc_a), "ari": float(ari_sc_a), "nmi": float(nmi_sc_a)},
        "spectral_adjacency_refined": {"acc": float(acc_rf_a), "ari": float(ari_rf_a), "nmi": float(nmi_rf_a)},

        # Logistic-regression baseline.
        "logistic": {
            "acc": float(acc_logistic),
            "ari": float(ari_logistic),
            "nmi": float(nmi_logistic),
        },
        "logistic_refined": {
            "acc": float(acc_logistic_refined),
            "ari": float(ari_logistic_refined),
            "nmi": float(nmi_logistic_refined),
        },
    }

    return float(loss_test_first), metrics



# Names of all evaluated methods. These match the keys of the `metrics` dict
# returned by test_single_first_period and become the result columns of the
# rows built in test_first_period.
METHODS = ["gnn", "gnn_refined",

           "spectral_normalized", "spectral_normalized_refined",

           "spectral_unnormalized", "spectral_unnormalized_refined",

           "spectral_adjacency","spectral_adjacency_refined",

           "logistic", "logistic_refined"]


def test_first_period(
    gnn_first_period,
    lr_model,
    n_classes,
    gen,
    args,
    iters=None,
    mode='balanced',
    class_sizes=None,
):
    """
    Run ``iters`` single-graph evaluations and return three summary rows.

    The rows are dicts for ACC, ARI and NMI (in that order).  Each row holds
    the run's meta information followed by one column per entry of
    ``METHODS`` containing the NaN-ignoring mean of that metric over all
    iterations.  Nothing is written to disk here; ``append_rows_to_excel``
    can persist the rows.

    Relies on ``test_single_first_period(...) -> (loss, metrics)`` where
    ``metrics[method]`` is a dict with the keys ``"acc"``, ``"ari"`` and
    ``"nmi"``.  Missing or non-finite values are recorded as NaN.

    ``iters`` defaults to ``args.num_examples_test``.
    """
    if iters is None:
        iters = args.num_examples_test

    gnn_first_period.eval()

    metric_names = ("acc", "ari", "nmi")

    # Collector: one list of values per method and per metric.
    buckets = {meth: {name: [] for name in metric_names} for meth in METHODS}

    for it in range(iters):
        _, metrics = test_single_first_period(
            gnn_first_period=gnn_first_period,
            lr_model=lr_model,
            gen=gen,
            n_classes=n_classes,
            args=args,
            iter=it,
            mode=mode,
            class_sizes=class_sizes,
        )

        for meth in METHODS:
            per_method = metrics.get(meth, {})
            for name in metric_names:
                value = per_method.get(name, np.nan)
                if np.isfinite(value):
                    buckets[meth][name].append(float(value))
                else:
                    buckets[meth][name].append(np.nan)

        torch.cuda.empty_cache()

    def mean(vals):
        """NaN-ignoring mean of ``vals`` as a Python float."""
        return float(np.nanmean(np.asarray(vals, dtype=float)))

    # Meta information: graph size n, then a = p / (log n / n),
    # b = q / (log n / n) and the resulting SNR.
    n_train = getattr(args, "N_train", getattr(gen, "N_train", None))
    n = n_train or getattr(args, "N_test", getattr(gen, "N_test", -1))
    if n and n > 0:
        logn_div_n = np.log(n) / n
    else:
        logn_div_n = np.nan
    if np.isfinite(logn_div_n):
        a = gen.p_SBM / logn_div_n
        b = gen.q_SBM / logn_div_n
    else:
        a = np.nan
        b = np.nan
    k = n_classes
    if np.all(np.isfinite([a, b])):
        snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
    else:
        snr = np.nan

    meta = {
        "n_classes": int(n_classes),
        "class_sizes": str(class_sizes) if class_sizes is not None else "",
        "p_SBM": float(gen.p_SBM),
        "q_SBM": float(gen.q_SBM),
        "J": int(args.J),
        "N_train": int(getattr(args, "N_train", getattr(gen, "N_train", -1))),
        "N_test": int(getattr(args, "N_test", getattr(gen, "N_test", -1))),
        "SNR": float(snr) if np.isfinite(snr) else np.nan,
    }

    # Assemble the three rows: meta columns first, then one column per method.
    rows = {name: dict(meta) for name in metric_names}
    for meth in METHODS:
        for name in metric_names:
            rows[name][meth] = mean(buckets[meth][name])

    return rows["acc"], rows["ari"], rows["nmi"]



from pathlib import Path
import pandas as pd
import numpy as np

def append_rows_to_excel(row_acc: dict, row_ari: dict, row_nmi: dict, filename="summary.xlsx", extra_info: dict=None):
    """
    Append one row to each of the ``ACC`` / ``ARI`` / ``NMI`` sheets of an Excel file.

    - If the file does not exist it is created with a header row.  If it
      exists, the sheet is read, merged with the new row (columns aligned as
      the union of old and new, old columns first) and written back, replacing
      the sheet.
    - ``extra_info`` (e.g. ``{"class_sizes": [100, 900], "snr": 0.5}``) is
      merged into the row on all three sheets and overrides equal keys.  To
      keep filtering easy, a list/tuple/array ``class_sizes`` is stored as a
      dash-joined string such as ``'100-900'``; any other value is stored as
      ``str(value)``.
    """
    def _normalize_extra(ei: dict):
        """Return a copy of ``ei`` with ``class_sizes`` turned into a string."""
        if not ei:
            return {}
        normalized = dict(ei)
        if "class_sizes" in normalized:
            sizes = normalized["class_sizes"]
            if isinstance(sizes, (list, tuple, np.ndarray)):
                normalized["class_sizes"] = "-".join(map(str, sizes))
            else:
                normalized["class_sizes"] = str(sizes)
        return normalized

    def _append_row(sheet_name: str, row: dict):
        """Append ``row`` (plus the extra info) to ``sheet_name``."""
        record = {**row, **_normalize_extra(extra_info)}
        df_new = pd.DataFrame([record])
        target = Path(filename)

        if not target.exists():
            # Fresh workbook.
            with pd.ExcelWriter(filename, engine="openpyxl", mode="w") as writer:
                df_new.to_excel(writer, sheet_name=sheet_name, index=False)
            return

        try:
            df_old = pd.read_excel(target, sheet_name=sheet_name)
            # Align both frames on the ordered union of their columns.
            cols = list(dict.fromkeys([*df_old.columns, *df_new.columns]))
            df_out = pd.concat(
                [df_old.reindex(columns=cols), df_new.reindex(columns=cols)],
                ignore_index=True,
            )
        except ValueError:
            # The requested sheet does not exist yet: start it with the new row.
            df_out = df_new

        # Write the sheet back, replacing any existing one of the same name.
        with pd.ExcelWriter(filename, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
            df_out.to_excel(writer, sheet_name=sheet_name, index=False)

    for sheet, row in (("ACC", row_acc), ("ARI", row_ari), ("NMI", row_nmi)):
        _append_row(sheet, row)



def test_first_period_wrapper(args_tuple):
    """Evaluate one ``(class_sizes, snr)`` configuration and return its summary rows.

    ``args_tuple`` is ``(gnn_first_period, lr_model, class_sizes, snr, gen,
    logN_div_N, total_ab)``.  The SBM parameters are derived from ``snr`` via
    ``find_a_given_snr``, scaled by ``logN_div_N`` and rounded to four
    decimals, then set on a copy of ``gen`` (``gen.copy()`` must exist).
    The remaining settings come from the module-level ``args``.

    Returns the ``(row_acc, row_ari, row_nmi)`` triple of ``test_first_period``.
    """
    (gnn_first_period, lr_model, class_sizes, snr,
     gen, logN_div_N, total_ab) = args_tuple

    a, b = find_a_given_snr(snr, args.n_classes, total_ab)
    p_SBM = round(a * logN_div_N, 4)
    q_SBM = round(b * logN_div_N, 4)

    # Work on a private generator copy so the caller's generator is untouched.
    gen_local = gen.copy()
    gen_local.p_SBM = p_SBM
    gen_local.q_SBM = q_SBM

    print(f"\n[测试阶段] class_sizes: {class_sizes}, SNR: {snr:.2f}")
    print(f"使用的 SBM 参数: p={p_SBM}, q={q_SBM}")

    eval_kwargs = dict(
        gnn_first_period=gnn_first_period,
        lr_model=lr_model,
        n_classes=args.n_classes,
        gen=gen_local,
        args=args,
        iters=args.num_examples_test,
        mode=args.mode_isbalanced,   # 'balanced' or 'imbalanced'
        class_sizes=class_sizes,
    )
    row_acc, row_ari, row_nmi = test_first_period(**eval_kwargs)

    return row_acc, row_ari, row_nmi


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Script entry point: configure the generator from the CLI arguments and, in
# "train" mode, run the two training stages (first-stage GNN, then the
# second-stage GNN on top of the frozen first stage).
if __name__ == '__main__':
    try:
        setup_logger("run_gnn")
        gen = Generator()
        gen.N_train = args.N_train
        gen.N_test = args.N_test
        gen.N_val = args.N_val

        gen.edge_density = args.edge_density
        gen.p_SBM = args.p_SBM
        gen.q_SBM = args.q_SBM

        gen.random_noise = args.random_noise
        gen.noise = args.noise
        gen.noise_model = args.noise_model
        gen.generative_model = args.generative_model
        gen.n_classes = args.n_classes

        gen.num_examples_train = args.num_examples_train
        gen.num_examples_test = args.num_examples_test
        gen.num_examples_val = args.num_examples_val

        # 1. Root folder for all saved models.
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)

        # 2. Sub-folder keyed by the number of classes.
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        # 3. Checkpoint paths for the two stages.
        filename_first = (
            f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        )
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = (
            f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        )

        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            ################################################################################################################
            # Here is the train period, prepare the dataloader we need to train
            gen.prepare_data()

            # Number of DataLoader worker processes.
            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            # One DataLoader per split; only the training split is shuffled.
            train_loader = DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            test_loader = DataLoader(
                test_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=num_workers,
                collate_fn=simple_collate_fn

            )

            print(
                f"[Stage 1] Training GNN：n_classes={args.n_classes}, layers={args.num_layers}, J={args.J}, num_features={args.num_features}")

            # Models are only defined for the multiclass SBM.  Any other value
            # used to leave the model variable unbound and fail with an opaque
            # NameError; report the real cause instead.
            if args.generative_model != 'SBM_multiclass':
                raise ValueError(
                    f"Unsupported generative_model {args.generative_model!r}; "
                    "only 'SBM_multiclass' is supported"
                )

            # === Stage 1 ===
            if os.path.exists(path_first):
                print(f"[Stage 1] Detect that there is already an 'optimal weight' and load it directly: {path_first}")
                # Build a fresh model and load the stored best weights into it.
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3,
                                                  n_classes=args.n_classes)
                gnn_first_period = gnn_first_period.to(device)
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)
            else:
                # Initialise and train the first stage.
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3,
                                                  n_classes=args.n_classes)
                if torch.cuda.is_available():
                    gnn_first_period = gnn_first_period.to(device)

                loss_list, acc_list = train_first_period_with_early_stopping(
                    gnn_first_period, train_loader, val_loader, args.n_classes, args,
                    epochs=100, patience=3, save_path=path_first, filename=filename_first
                )

                print(f"[Phase 1] Backload from optimal weight to memory: {path_first}")
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)

            # The first stage is frozen before the second stage is trained.
            maybe_freeze(gnn_first_period, freeze=True)

            # === Stage 2 ===
            print(
                f"[Stage 2] training GNN：n_classes={args.n_classes}, layers={args.num_layers_second}, J={args.J_second}, num_features={args.num_features_second}")

            if os.path.exists(path_second):
                print(f"[Stage 2] Detect the existing model file and load it directly: {path_second}")
                # NOTE: weights_only=False unpickles the file; it must be trusted.
                gnn_second_period = torch.load(path_second, map_location=device, weights_only=False)
            else:
                gnn_second_period = GNN_multiclass_second_period(
                    args.num_features_second, args.num_layers_second, args.J_second + 2, n_classes=args.n_classes
                )
                if torch.cuda.is_available():
                    gnn_second_period = gnn_second_period.to(device)

                loss_list, acc_list = train_second_period_with_early_stopping(
                    gnn_first_period, gnn_second_period, train_loader, val_loader,
                    args.n_classes, args,
                    epochs=100, patience=3, save_path=path_second, filename=filename_second
                )

                print(f"[Phase 2] Save the model to {path_second}")

    except Exception:

        import traceback

        # stderr normally points at the log file by now, so the traceback
        # ends up in the run's log.
        traceback.print_exc()
        # Report the failure to the caller: previously the script exited with
        # status 0 even when the run had crashed.
        sys.exit(1)